{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import os\n",
    "\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MNIST dataset and unpack its (images, labels) pairs.\n",
    "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append a trailing size-1 axis so the images match the model's (None, None, 1) input shape.\n",
    "train_images = tf.expand_dims(train_images, axis=-1)\n",
    "test_images = tf.expand_dims(test_images, axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Divide pixel values by 255 and store the result as float32.\n",
    "train_images = tf.cast(train_images / 255, dtype=tf.float32)\n",
    "test_images = tf.cast(test_images / 255, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels are cast to int64 before being paired with the images.\n",
    "train_labels = tf.cast(train_labels, dtype=tf.int64)\n",
    "test_labels = tf.cast(test_labels, dtype=tf.int64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One dataset element per (image, label) pair.\n",
    "dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))\n",
    "test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle BEFORE repeat so every sample is seen once per pass; repeat() placed\n",
    "# ahead of shuffle() would mix samples from neighbouring epochs in the buffer.\n",
    "# repeat() with no count makes both datasets infinite, which suits model.fit\n",
    "# when steps_per_epoch / validation_steps are given. A plain Python loop over\n",
    "# either dataset never ends on its own -- bound it with .take(n).\n",
    "dataset = dataset.shuffle(60000).repeat().batch(128)\n",
    "test_dataset = test_dataset.repeat().batch(128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small convolutional classifier: two 3x3 conv layers, global max pooling,\n",
    "# then a 10-way softmax. Spatial dims are left as None in input_shape.\n",
    "model = tf.keras.Sequential()\n",
    "model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(None, None, 1)))\n",
    "model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu'))\n",
    "model.add(tf.keras.layers.GlobalMaxPooling2D())\n",
    "model.add(tf.keras.layers.Dense(10, activation='softmax'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam optimizer, sparse categorical cross-entropy loss, accuracy as the tracked metric.\n",
    "model.compile(\n",
    "    optimizer='adam',\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    metrics=['accuracy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tensorboard可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Timestamped sub-directory of logs/ so each run writes to its own folder.\n",
    "log_dir = os.path.join('logs', datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# log_dir is where the callback stores its event files.\n",
    "tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary file writer that targets the 'lr' sub-directory of log_dir.\n",
    "file_writer = tf.summary.create_file_writer(f'{log_dir}/lr')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make it the default summary file writer.\n",
    "file_writer.set_as_default()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lr_sche(epoch):\n",
    "    \"\"\"Return the learning rate for `epoch` and log it as a TensorBoard scalar.\n",
    "\n",
    "    Piecewise-constant schedule (thresholds are compared against the epoch\n",
    "    index handed in by the LearningRateScheduler callback):\n",
    "\n",
    "        epoch <= 5        -> 1e-3\n",
    "        5 < epoch <= 10   -> 1e-4\n",
    "        10 < epoch <= 20  -> 5e-5\n",
    "        epoch > 20        -> 2.5e-5\n",
    "\n",
    "    The earlier values (0.2 / 0.02 / 0.01 / 0.005) are far too large for the\n",
    "    Adam optimizer this model is compiled with: training stayed at chance\n",
    "    level (loss ~2.30, accuracy ~0.11). The decay ratios are unchanged; the\n",
    "    whole schedule is rescaled so that it starts at 1e-3.\n",
    "\n",
    "    Args:\n",
    "        epoch: integer epoch index.\n",
    "\n",
    "    Returns:\n",
    "        The learning rate (float) to use for that epoch.\n",
    "    \"\"\"\n",
    "    if epoch > 20:\n",
    "        learning_rate = 2.5e-5\n",
    "    elif epoch > 10:\n",
    "        learning_rate = 5e-5\n",
    "    elif epoch > 5:\n",
    "        learning_rate = 1e-4\n",
    "    else:\n",
    "        learning_rate = 1e-3\n",
    "    # Record the value with tf.summary.scalar, using the epoch as the step.\n",
    "    tf.summary.scalar('learning_rate', data=learning_rate, step=epoch)\n",
    "    return learning_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keras callback that asks lr_sche for the learning rate.\n",
    "lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule=lr_sche)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/25\n",
      "  2/468 [..............................] - ETA: 1:48 - loss: 2.2985 - accuracy: 0.1055WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0880s vs `on_train_batch_end` time: 0.3763s). Check your callbacks.\n",
      "468/468 [==============================] - 47s 101ms/step - loss: 2.3133 - accuracy: 0.1035 - val_loss: 2.3125 - val_accuracy: 0.1135 35s - loss: 2.3149 - acc - ETA: 33s - loss: 2.3141 - accuracy: - ETA: 32s - loss: 2.3140 - accuracy: 0. - ETA: 32s - loss: 2.3137 - accuracy - ETA: 31s - loss: 2.3136 - accuracy: 0 - ETA: 31s - los - ETA: 28s - lo - ETA: 25s - loss: 2.3151   - ETA: 4s - loss: 2.3137 - accu - ETA: 1s - loss: 2.3135 - ac - ETA: 0s - loss: 2.313\n",
      "Epoch 2/25\n",
      "468/468 [==============================] - 47s 100ms/step - loss: 2.3147 - accuracy: 0.1049 - val_loss: 2.3122 - val_accuracy: 0.0975 - lo - ETA: 22s - loss: 2.3131 - accuracy: - ETA: 22s - loss: 2.3132 - accura - ETA: 21s - loss: 2.3132 - ac - ETA\n",
      "Epoch 3/25\n",
      "468/468 [==============================] - 51s 109ms/step - loss: 2.3135 - accuracy: 0.1033 - val_loss: 2.3056 - val_accuracy: 0.102935 - accur - ETA: - ETA: 30s -  - ETA: 6s - loss: 2.3135 - accuracy: 0.10 - ETA: 6s - l - ETA:  - ETA: 3s - loss: 2.3134  - ETA: 2s - loss: 2.3135 - accuracy: 0.10 - ETA: 1s - loss: 2.3134 - accura - ETA: 1s - l\n",
      "Epoch 4/25\n",
      "468/468 [==============================] - 49s 106ms/step - loss: 2.3141 - accuracy: 0.1004 - val_loss: 2.3052 - val_accuracy: 0.1010 - loss: 2.3131 - ac - ETA: 17s - loss: 2.313 - ETA: 10s  -\n",
      "Epoch 5/25\n",
      "468/468 [==============================] - 52s 111ms/step - loss: 2.3114 - accuracy: 0.1055 - val_loss: 2.3161 - val_accuracy: 0.1032: 2.3\n",
      "Epoch 6/25\n",
      "468/468 [==============================] - 56s 120ms/step - loss: 2.3128 - accuracy: 0.1040 - val_loss: 2.3084 - val_accuracy: 0.0975\n",
      "Epoch 7/25\n",
      "468/468 [==============================] - 51s 108ms/step - loss: 2.3024 - accuracy: 0.1108 - val_loss: 2.3015 - val_accuracy: 0.1135\n",
      "Epoch 8/25\n",
      "468/468 [==============================] - 55s 119ms/step - loss: 2.3026 - accuracy: 0.1067 - val_loss: 2.3016 - val_accuracy: 0.1135\n",
      "Epoch 9/25\n",
      "468/468 [==============================] - 56s 119ms/step - loss: 2.3024 - accuracy: 0.1105 - val_loss: 2.3013 - val_accuracy: 0.1135\n",
      "Epoch 10/25\n",
      "468/468 [==============================] - 53s 114ms/step - loss: 2.3025 - accuracy: 0.1086 - val_loss: 2.3028 - val_accuracy: 0.1010\n",
      "Epoch 11/25\n",
      "468/468 [==============================] - 53s 114ms/step - loss: 2.3025 - accuracy: 0.1072 - val_loss: 2.3025 - val_accuracy: 0.1029\n",
      "Epoch 12/25\n",
      "468/468 [==============================] - 51s 109ms/step - loss: 2.3024 - accuracy: 0.1083 - val_loss: 2.3020 - val_accuracy: 0.1135\n",
      "Epoch 13/25\n",
      "468/468 [==============================] - 50s 106ms/step - loss: 2.3012 - accuracy: 0.1143 - val_loss: 2.3017 - val_accuracy: 0.1135\n",
      "Epoch 14/25\n",
      "468/468 [==============================] - 51s 110ms/step - loss: 2.3019 - accuracy: 0.1103 - val_loss: 2.3019 - val_accuracy: 0.1135\n",
      "Epoch 15/25\n",
      "468/468 [==============================] - 50s 108ms/step - loss: 2.3018 - accuracy: 0.1128 - val_loss: 2.3018 - val_accuracy: 0.1029\n",
      "Epoch 16/25\n",
      "468/468 [==============================] - 50s 107ms/step - loss: 2.3021 - accuracy: 0.1106 - val_loss: 2.3016 - val_accuracy: 0.1135\n",
      "Epoch 17/25\n",
      "468/468 [==============================] - 50s 107ms/step - loss: 2.3018 - accuracy: 0.1108 - val_loss: 2.3016 - val_accuracy: 0.1029\n",
      "Epoch 18/25\n",
      "468/468 [==============================] - 51s 109ms/step - loss: 2.3021 - accuracy: 0.1105 - val_loss: 2.3018 - val_accuracy: 0.1135\n",
      "Epoch 19/25\n",
      "468/468 [==============================] - 51s 109ms/step - loss: 2.3020 - accuracy: 0.1106 - val_loss: 2.3016 - val_accuracy: 0.1135\n",
      "Epoch 20/25\n",
      "468/468 [==============================] - 53s 113ms/step - loss: 2.3018 - accuracy: 0.1106 - val_loss: 2.3016 - val_accuracy: 0.1135\n",
      "Epoch 21/25\n",
      "468/468 [==============================] - 50s 107ms/step - loss: 2.3016 - accuracy: 0.1147 - val_loss: 2.3014 - val_accuracy: 0.1135\n",
      "Epoch 22/25\n",
      "468/468 [==============================] - 50s 108ms/step - loss: 2.3018 - accuracy: 0.1107 - val_loss: 2.3011 - val_accuracy: 0.1135\n",
      "Epoch 23/25\n",
      "468/468 [==============================] - 50s 107ms/step - loss: 2.3013 - accuracy: 0.1120 - val_loss: 2.3012 - val_accuracy: 0.1135\n",
      "Epoch 24/25\n",
      "468/468 [==============================] - 50s 107ms/step - loss: 2.3016 - accuracy: 0.1122 - val_loss: 2.3013 - val_accuracy: 0.1135\n",
      "Epoch 25/25\n",
      "468/468 [==============================] - 50s 107ms/step - loss: 2.3012 - accuracy: 0.1121 - val_loss: 2.3013 - val_accuracy: 0.1135\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x2658e3666c8>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Both datasets repeat, so the number of batches per epoch is given explicitly.\n",
    "model.fit(\n",
    "    dataset,\n",
    "    epochs=25,\n",
    "    steps_per_epoch=60000 // 128,\n",
    "    validation_data=test_dataset,\n",
    "    validation_steps=10000 // 128,\n",
    "    callbacks=[tensorboard_callback, lr_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The tensorboard extension is already loaded. To reload it, use:\n",
      "  %reload_ext tensorboard\n"
     ]
    }
   ],
   "source": [
    "# Display TensorBoard inside the notebook\n",
    "%load_ext tensorboard\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Reusing TensorBoard on port 6006 (pid 1420), started 6:08:55 ago. (Use '!kill 1420' to kill it.)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "      <iframe id=\"tensorboard-frame-c91439b05bce91e0\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
       "      </iframe>\n",
       "      <script>\n",
       "        (function() {\n",
       "          const frame = document.getElementById(\"tensorboard-frame-c91439b05bce91e0\");\n",
       "          const url = new URL(\"/\", window.location);\n",
       "          const port = 6006;\n",
       "          if (port) {\n",
       "            url.port = port;\n",
       "          }\n",
       "          frame.src = url;\n",
       "        })();\n",
       "      </script>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Launch TensorBoard; the argument is the log directory, here logs\n",
    "%tensorboard --logdir logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To open TensorBoard in a browser instead:\n",
    "# 1. Open a cmd prompt and change to the directory containing this code\n",
    "# 2. Run the command: tensorboard --logdir logs\n",
    "# 3. Copy the URL it prints into the browser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 记录自定义标量值"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "重新训练模型并记录自定义学习率：\n",
    "\n",
    "1. 使用 tf.summary.create_file_writer() 创建文件编写器\n",
    "2. 定义自定义学习率函数，并将其传递给 Keras 的 LearningRateScheduler 回调\n",
    "3. 在学习率函数内，使用 tf.summary.scalar() 记录自定义学习率\n",
    "4. 将 LearningRateScheduler 回调传递给 Model.fit()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自定义训练中使用TensorBoard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimizer for the custom training loop. The module path is tf.keras.optimizers;\n",
    "# the earlier misspelled attribute name raised AttributeError.\n",
    "optimizer = tf.keras.optimizers.Adam()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss object used by the custom train/test steps; the model already ends in a softmax.\n",
    "loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss(model, x, y):\n",
    "    \"\"\"Return loss_func(y, model(x)): the loss of `model`'s output on `x` against `y`.\"\"\"\n",
    "    return loss_func(y, model(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 初始化计算对象\n",
    "train_loss = tf.keras.metrics.Mean('train_loss')\n",
    "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy') \n",
    "test_loss = tf.keras.metrics.Mean('test_loss')\n",
    "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy') "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_step(model, images, labels):\n",
    "    \"\"\"Run one optimisation step on a batch and update the training metrics.\n",
    "\n",
    "    Args:\n",
    "        model: callable Keras model whose trainable_variables are updated.\n",
    "        images: batch of inputs passed to the model.\n",
    "        labels: targets passed to loss_func and train_accuracy.\n",
    "    \"\"\"\n",
    "    with tf.GradientTape() as tape:\n",
    "        # training=True so layers with distinct train/inference behaviour\n",
    "        # (e.g. Dropout, BatchNormalization) run in training mode.\n",
    "        pred = model(images, training=True)\n",
    "        loss_step = loss_func(labels, pred)\n",
    "    # Gradients of the batch loss with respect to the trainable parameters.\n",
    "    grads = tape.gradient(loss_step, model.trainable_variables)\n",
    "    # Let the optimizer apply those gradients to the parameters.\n",
    "    optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "    # Fold this batch into the running mean loss and accuracy.\n",
    "    train_loss(loss_step)\n",
    "    train_accuracy(labels, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_step(model, images, labels):\n",
    "    \"\"\"Evaluate one batch without updating weights and update the test metrics.\n",
    "\n",
    "    Args:\n",
    "        model: callable Keras model to evaluate.\n",
    "        images: batch of inputs passed to the model.\n",
    "        labels: targets passed to loss_func and test_accuracy.\n",
    "    \"\"\"\n",
    "    # training=False so layers with distinct train/inference behaviour\n",
    "    # (e.g. Dropout, BatchNormalization) run in inference mode.\n",
    "    pred = model(images, training=False)\n",
    "    loss_step = loss_func(labels, pred)\n",
    "    # Fold this batch into the running mean loss and accuracy.\n",
    "    test_loss(loss_step)\n",
    "    test_accuracy(labels, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Timestamp string used to name the log directories of the custom-loop run.\n",
    "current_time = '{:%Y%m%d-%H%M%S}'.format(datetime.datetime.now())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep 'gradient_tape' as its own directory level: plain string concatenation\n",
    "# without a separator would produce 'logs/gradient_tape<timestamp>/...'.\n",
    "train_log_dir = os.path.join('logs', 'gradient_tape', current_time, 'train')\n",
    "test_log_dir = os.path.join('logs', 'gradient_tape', current_time, 'test')\n",
    "\n",
    "# One summary writer per directory, so train and test values go to separate locations.\n",
    "train_writer = tf.summary.create_file_writer(train_log_dir)\n",
    "test_writer = tf.summary.create_file_writer(test_log_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epochs=10, steps_per_epoch=60000 // 128, validation_steps=10000 // 128):\n",
    "    \"\"\"Custom training loop that logs per-epoch loss/accuracy to TensorBoard.\n",
    "\n",
    "    `dataset` and `test_dataset` are built with repeat() and therefore never\n",
    "    run out, so iterating over them directly would never leave the first\n",
    "    epoch. Each pass is bounded with take() instead.\n",
    "\n",
    "    Args:\n",
    "        epochs: number of epochs to run.\n",
    "        steps_per_epoch: training batches consumed per epoch.\n",
    "        validation_steps: test batches consumed per epoch.\n",
    "    \"\"\"\n",
    "    for epoch in range(epochs):\n",
    "        for images, labels in dataset.take(steps_per_epoch):\n",
    "            train_step(model, images, labels)\n",
    "        # Record this epoch's aggregated training scalars.\n",
    "        with train_writer.as_default():\n",
    "            tf.summary.scalar('loss', train_loss.result(), step=epoch)\n",
    "            tf.summary.scalar('acc', train_accuracy.result(), step=epoch)\n",
    "        print('Epoch{}  loss is {}, accuracy is {}'.format(epoch, train_loss.result(), train_accuracy.result()))\n",
    "\n",
    "        for images, labels in test_dataset.take(validation_steps):\n",
    "            test_step(model, images, labels)\n",
    "            print('*', end='')\n",
    "        # Record this epoch's aggregated test scalars.\n",
    "        with test_writer.as_default():\n",
    "            tf.summary.scalar('loss', test_loss.result(), step=epoch)\n",
    "            tf.summary.scalar('acc', test_accuracy.result(), step=epoch)\n",
    "        print('Epoch{}  test_loss is {}, test_accuracy is {}'.format(epoch, test_loss.result(), test_accuracy.result()))\n",
    "\n",
    "        # Reset the accumulators so the next epoch starts from zero.\n",
    "        train_loss.reset_states()\n",
    "        train_accuracy.reset_states()\n",
    "        test_loss.reset_states()\n",
    "        test_accuracy.reset_states()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the custom training loop.\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
